Select

根据条件张量逐元素选择输入值。对于每个输出位置,如果条件为真(True),则选择 input0 的值;否则选择 input1 的值。该算子支持广播机制。

\[\begin{split}\text{output}_i = \begin{cases} \text{input0}[idx2], & \text{if } \text{condition}[idx1] = \text{True} \\ \text{input1}[idx3], & \text{if } \text{condition}[idx1] = \text{False} \end{cases}\end{split}\]

其中,当不需要广播时(is_broadcast = 0),idx1 = idx2 = idx3 = i;当需要广播时(is_broadcast = 1),使用索引映射 index_list1、index_list2、index_list3 来确定各个输入张量的索引,即 idx1 = index_list1[i],idx2 = index_list2[i],idx3 = index_list3[i]。

输入:
  • input0 - 第一个输入数据地址。当条件为真时选择此值。

  • input1 - 第二个输入数据地址。当条件为假时选择此值。

  • condition - 条件数据地址(bool类型)。决定选择哪个输入的值。

  • output_dims - 输出张量的维度信息数组。

  • output_dims_num - 输出张量的维度数。

  • index_list1 - 条件张量的索引映射数组,用于广播场景。大小为输出总元素数。

  • index_list2 - input0 的索引映射数组,用于广播场景。大小为输出总元素数。

  • index_list3 - input1 的索引映射数组,用于广播场景。大小为输出总元素数。

  • is_broadcast - 是否需要广播的标志。0 表示不需要广播,1 表示需要广播。

  • core_mask - 核掩码(仅共享存储版本需要)。

输出:
  • output - 输出数据地址,其形状由 output_dims 和 output_dims_num 确定。

支持平台:

FT78NE、MT7004

备注

  • FT78NE 支持fp32, int8, int16, int32, fp64, cplx64, cplx128

  • MT7004 支持fp16, fp32, int16, int32, cplx64

共享存储版本:

/*
 * Shared-storage Select variants, one per element type.
 *
 * For every output position i:
 *   output[i] = condition[idx1] ? input0[idx2] : input1[idx3]
 * with idx1 = idx2 = idx3 = i when is_broadcast == 0, and
 * idx1 = index_list1[i], idx2 = index_list2[i], idx3 = index_list3[i]
 * when is_broadcast == 1. Each index list holds one entry per output element.
 *
 * core_mask is taken only by these shared-storage (_s) variants.
 *
 * Platform support (per the platform notes of this page):
 *   i8, dp, c128       - FT78NE only
 *   hp                 - MT7004 only
 *   i16, i32, fp, c64  - FT78NE and MT7004
 *
 * c64/c128 take float / double pointers. NOTE(review): presumably each complex
 * element is stored as an interleaved (real, imag) pair -- confirm the layout
 * and whether index lists count complex elements or scalar components.
 */
void i8_select_s(int8_t *input0, int8_t *input1, bool *condition, int8_t *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast, int core_mask)
void i16_select_s(int16_t *input0, int16_t *input1, bool *condition, int16_t *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast, int core_mask)
void i32_select_s(int32_t *input0, int32_t *input1, bool *condition, int32_t *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast, int core_mask)
void hp_select_s(half *input0, half *input1, bool *condition, half *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast, int core_mask)
void fp_select_s(float *input0, float *input1, bool *condition, float *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast, int core_mask)
void dp_select_s(double *input0, double *input1, bool *condition, double *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast, int core_mask)
void c64_select_s(float *input0, float *input1, bool *condition, float *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast, int core_mask)
void c128_select_s(double *input0, double *input1, bool *condition, double *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast, int core_mask)

C调用示例(无广播):

 1//FT78NE示例
 2#include <stdio.h>
 3#include <select.h>
 4
 5int main(int argc, char* argv[]) {
 6    // 假设在DDR空间
 7    float *input0 = (float *)0xA0000000;
 8    float *input1 = (float *)0xA1000000;
 9    bool *condition = (bool *)0xA2000000;
10    float *output = (float *)0xB0000000;
11
12    // 输出形状 [2, 3, 4]
13    unsigned long long output_dims[] = {2, 3, 4};
14    unsigned long long output_dims_num = 3;
15
16    // 计算总元素数
17    unsigned long long total_elements = 2 * 3 * 4; // 24
18
19    // 索引映射数组(无广播时可以为NULL或与输出索引相同)
20    unsigned long long *index_list1 = (unsigned long long *)0xC0000000;
21    unsigned long long *index_list2 = (unsigned long long *)0xC0100000;
22    unsigned long long *index_list3 = (unsigned long long *)0xC0200000;
23
24    // 初始化索引映射(无广播时直接使用顺序索引)
25    for (unsigned long long i = 0; i < total_elements; i++) {
26        index_list1[i] = i;
27        index_list2[i] = i;
28        index_list3[i] = i;
29    }
30
31    long long is_broadcast = 0;  // 不需要广播
32    int core_mask = 0xff;
33
34    fp_select_s(input0, input1, condition, output, output_dims, output_dims_num,
35               index_list1, index_list2, index_list3, is_broadcast, core_mask);
36    return 0;
37}

C调用示例(有广播):

 1//FT78NE示例
 2#include <stdio.h>
 3#include <select.h>
 4
 5int main(int argc, char* argv[]) {
 6    // 假设在DDR空间
 7    float *input0 = (float *)0xA0000000;  // 形状 [3, 4]
 8    float *input1 = (float *)0xA1000000;   // 标量或形状 [1]
 9    bool *condition = (bool *)0xA2000000;  // 形状 [2, 3, 4]
10    float *output = (float *)0xB0000000;   // 形状 [2, 3, 4]
11
12    // 输出形状 [2, 3, 4]
13    unsigned long long output_dims[] = {2, 3, 4};
14    unsigned long long output_dims_num = 3;
15
16    // 计算总元素数
17    unsigned long long total_elements = 2 * 3 * 4; // 24
18
19    // 索引映射数组(需要根据广播规则预先计算)
20    unsigned long long *index_list1 = (unsigned long long *)0xC0000000;
21    unsigned long long *index_list2 = (unsigned long long *)0xC0100000;
22    unsigned long long *index_list3 = (unsigned long long *)0xC0200000;
23
24    // 初始化索引映射(示例:需要根据实际广播规则计算)
25    // 这里假设 condition 和 output 形状相同,input0 需要广播
26    for (unsigned long long i = 0; i < total_elements; i++) {
27        index_list1[i] = i;  // condition 索引
28        index_list2[i] = i % 12;  // input0 索引(假设需要广播)
29        index_list3[i] = 0;  // input1 索引(标量)
30    }
31
32    long long is_broadcast = 1;  // 需要广播
33    int core_mask = 0xff;
34
35    fp_select_s(input0, input1, condition, output, output_dims, output_dims_num,
36               index_list1, index_list2, index_list3, is_broadcast, core_mask);
37    return 0;
38}

私有存储版本:

/*
 * Private-storage Select variants, one per element type. Same contract as
 * the shared-storage (_s) variants, but without the core_mask parameter.
 *
 * For every output position i:
 *   output[i] = condition[idx1] ? input0[idx2] : input1[idx3]
 * with idx1 = idx2 = idx3 = i when is_broadcast == 0, and
 * idx1 = index_list1[i], idx2 = index_list2[i], idx3 = index_list3[i]
 * when is_broadcast == 1.
 *
 * Platform support (per the platform notes of this page):
 *   i8, dp, c128       - FT78NE only
 *   hp                 - MT7004 only
 *   i16, i32, fp, c64  - FT78NE and MT7004
 *
 * NOTE(review): c64/c128 take float / double pointers; presumably complex
 * elements are interleaved (real, imag) pairs -- confirm.
 */
void i8_select_p(int8_t *input0, int8_t *input1, bool *condition, int8_t *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast)
void i16_select_p(int16_t *input0, int16_t *input1, bool *condition, int16_t *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast)
void i32_select_p(int32_t *input0, int32_t *input1, bool *condition, int32_t *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast)
void hp_select_p(half *input0, half *input1, bool *condition, half *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast)
void fp_select_p(float *input0, float *input1, bool *condition, float *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast)
void dp_select_p(double *input0, double *input1, bool *condition, double *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast)
void c64_select_p(float *input0, float *input1, bool *condition, float *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast)
void c128_select_p(double *input0, double *input1, bool *condition, double *output, unsigned long long *output_dims, unsigned long long output_dims_num, unsigned long long *index_list1, unsigned long long *index_list2, unsigned long long *index_list3, long long is_broadcast)

C调用示例(私有存储版本):

 1//FT78NE示例
 2#include <stdio.h>
 3#include <select.h>
 4
 5int main(int argc, char* argv[]) {
 6    // 假设在L2空间
 7    float *input0 = (float *)0x10000000;
 8    float *input1 = (float *)0x10001000;
 9    bool *condition = (bool *)0x10002000;
10    float *output = (float *)0x10003000;
11
12    // 输出形状 [2, 3, 4]
13    unsigned long long output_dims[] = {2, 3, 4};
14    unsigned long long output_dims_num = 3;
15
16    unsigned long long total_elements = 2 * 3 * 4;
17    unsigned long long *index_list1 = (unsigned long long *)0x10004000;
18    unsigned long long *index_list2 = (unsigned long long *)0x10005000;
19    unsigned long long *index_list3 = (unsigned long long *)0x10006000;
20
21    // 初始化索引映射(无广播)
22    for (unsigned long long i = 0; i < total_elements; i++) {
23        index_list1[i] = i;
24        index_list2[i] = i;
25        index_list3[i] = i;
26    }
27
28    long long is_broadcast = 0;
29
30    fp_select_p(input0, input1, condition, output, output_dims, output_dims_num,
31               index_list1, index_list2, index_list3, is_broadcast);
32    return 0;
33}